package pres.lnk.utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Calendar;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
 * Automatically downloads and trusts the certificate of HTTPS servers.
 *
 * <p>For every HTTPS connection this initializer builds a trust store for the target host: a copy of the JDK
 * default {@code cacerts} store plus the certificate presented by the server (the first entry of its chain,
 * stored under the host name as alias). The store is cached as {@code <certificateDir>/<host>.cer} and installed
 * on the connection via {@link HttpsURLConnection#setSSLSocketFactory}. A cached file last modified before the
 * start of the current day is deleted and downloaded again, so a renewed server certificate is picked up.
 *
 * <p><b>Security note:</b> the certificate is fetched with a trust-all manager, so whatever certificate is
 * presented at download time becomes trusted. This offers no protection against a man-in-the-middle during the
 * download; only use it where that risk is acceptable.
 *
 * @author linnankun
 * @date 2022/4/25
 */
public class HttpsAutoAuthInitializer implements RestfulHttpClient.URLConnectionInitializer {
    private final static Logger logger = LoggerFactory.getLogger(HttpsAutoAuthInitializer.class);

    /** Default password of the JDK {@code cacerts} store; also protects the cached per-host key stores. */
    private static final String PASSWORD = "changeit";
    /** Absolute path of the running JDK's {@code cacerts} file, resolved once in the static initializer. */
    private static final String JDK_DEFAULT_CACERTS;
    /** Port used when the URL does not specify one. */
    private static final int DEFAULT_HTTPS_PORT = 443;

    /**
     * Per-host lock objects serializing deletion, download and loading of a host's cached certificate file.
     * A private map is used instead of {@code host.intern()}, because interned strings are shared across the
     * whole JVM and may be locked by unrelated code. Entries are never removed.
     */
    private static final ConcurrentMap<String, Object> HOST_LOCKS = new ConcurrentHashMap<>();

    private String password = PASSWORD;
    private String certificateDir = ".";
    private String jdkDefaultCacerts = JDK_DEFAULT_CACERTS;

    static {
        final String javaHome = System.getProperty("java.home");
        // Prefer <java.home>/lib/security/cacerts; fall back to the jre/ sub-directory layout.
        File file = new File(javaHome, "lib/security/cacerts");
        if (!file.exists()) {
            file = new File(javaHome, "jre/lib/security/cacerts");
        }
        JDK_DEFAULT_CACERTS = file.getAbsolutePath();
    }

    /** Holder of the shared instance; JVM class initialization guarantees it is created exactly once. */
    private static final class InstanceHolder {
        private static final HttpsAutoAuthInitializer INSTANCE = new HttpsAutoAuthInitializer();
    }

    /**
     * Returns the shared instance, creating it on first use.
     *
     * <p>The instance is mutable through its setters; changes are visible to every caller using it.
     *
     * @return the shared initializer
     */
    public static HttpsAutoAuthInitializer getInstance() {
        return InstanceHolder.INSTANCE;
    }

    /**
     * Installs a trust store containing the server's certificate on HTTPS connections; connections using any
     * other protocol are returned unchanged.
     *
     * @param connection the connection being prepared
     * @param client     the client issuing the request (not used)
     * @return the same connection
     * @throws RuntimeException if the certificate could not be downloaded or loaded
     */
    @Override
    public HttpURLConnection init(HttpURLConnection connection, RestfulHttpClient.HttpClient client) {
        String protocol = connection.getURL().getProtocol();
        if ("https".equalsIgnoreCase(protocol)) {
            HttpsURLConnection https = (HttpsURLConnection) connection;
            try {
                downloadAndLoadCertificate(https);
            } catch (Exception e) {
                throw new RuntimeException("load https certificate failed", e);
            }
            return https;
        }
        return connection;
    }

    /**
     * Ensures an up-to-date cached trust store exists for the connection's host and installs it on the
     * connection.
     */
    private void downloadAndLoadCertificate(HttpsURLConnection connection)
            throws IOException, GeneralSecurityException {
        String host = connection.getURL().getHost();
        int port = connection.getURL().getPort();
        // URL.getPort() returns -1 when the URL has no explicit port.
        port = port <= 0 ? DEFAULT_HTTPS_PORT : port;

        File file = new File(certificateDir, host.concat(".cer"));
        KeyStore ks;
        synchronized (lockFor(host)) {
            deleteIfStale(file);
            if (!file.exists()) {
                downloadCertificate(host, port, file, connection.getConnectTimeout(), connection.getReadTimeout());
            }
            ks = loadKeyStore(file);
        }
        logger.debug("{} loading https certificate: {} ", host, file.getAbsoluteFile());

        TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        tmf.init(ks);

        // A fresh context, separate from the trust-all one used for downloading.
        SSLContext sslContext = SSLContext.getInstance("TLS");
        sslContext.init(null, tmf.getTrustManagers(), null);
        SSLSocketFactory newFactory = sslContext.getSocketFactory();
        connection.setSSLSocketFactory(newFactory);
    }

    /** Returns the lock object dedicated to {@code host}, creating it atomically on first use. */
    private static Object lockFor(String host) {
        Object lock = HOST_LOCKS.get(host);
        if (lock == null) {
            Object created = new Object();
            lock = HOST_LOCKS.putIfAbsent(host, created);
            if (lock == null) {
                lock = created;
            }
        }
        return lock;
    }

    /**
     * Deletes the cached certificate file if it was last modified before the start of today (local time zone).
     * A certificate is therefore reused only on the day it was downloaded, so an expired or renewed server
     * certificate gets refreshed.
     */
    private static void deleteIfStale(File file) {
        if (file.exists() && file.lastModified() < startOfTodayMillis()) {
            if (!file.delete()) {
                logger.warn("failed to delete stale https certificate: {}", file.getAbsolutePath());
            }
        }
    }

    /** Returns the epoch milliseconds of 00:00:00.000 today in the default time zone. */
    private static long startOfTodayMillis() {
        Calendar calendar = Calendar.getInstance();
        // Every time-of-day field must be cleared; resetting only the hour would keep the current minute/second.
        calendar.set(Calendar.HOUR_OF_DAY, 0);
        calendar.set(Calendar.MINUTE, 0);
        calendar.set(Calendar.SECOND, 0);
        calendar.set(Calendar.MILLISECOND, 0);
        return calendar.getTimeInMillis();
    }

    /** Loads a key store of the default type from {@code file}, unlocking it with {@link #password}. */
    private KeyStore loadKeyStore(File file) throws IOException, GeneralSecurityException {
        KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
        try (FileInputStream input = new FileInputStream(file)) {
            ks.load(input, password.toCharArray());
        }
        return ks;
    }

    /**
     * Downloads the certificate of {@code host:port} and saves it, together with the JDK default CA
     * certificates, as a key store in {@code file}.
     *
     * @param connectTimeout connect timeout in milliseconds for the download, {@code 0} for none
     * @param readTimeout    read timeout in milliseconds for the download handshake, {@code 0} for none
     * @throws FileNotFoundException if the JDK default cacerts file does not exist
     * @throws CertificateException  if the server presented no certificate
     */
    private void downloadCertificate(String host, int port, File file, int connectTimeout, int readTimeout)
            throws IOException, GeneralSecurityException {
        File cacerts = new File(jdkDefaultCacerts);
        if (!cacerts.exists()) {
            throw new FileNotFoundException("jdk default cacerts not found:" + jdkDefaultCacerts);
        }
        KeyStore ks = loadKeyStore(cacerts);

        X509Certificate[] chain = fetchServerCertificateChain(host, port, connectTimeout, readTimeout);
        if (logger.isDebugEnabled()) {
            for (int i = 0; i < chain.length; i++) {
                logger.debug("find out https certificate {}: {} ", (i + 1), chain[i].getSubjectX500Principal());
            }
        }

        // Only the leaf certificate is added, under the host name as alias.
        ks.setCertificateEntry(host, chain[0]);
        try (FileOutputStream output = new FileOutputStream(file)) {
            ks.store(output, password.toCharArray());
        } catch (IOException | GeneralSecurityException e) {
            // Do not leave a truncated key store behind: it would fail to load until the next daily refresh.
            if (file.exists() && !file.delete()) {
                logger.warn("failed to delete incomplete https certificate: {}", file.getAbsolutePath());
            }
            throw e;
        }
        logger.info("download {} https certificate: {}", host, chain[0]);
    }

    /**
     * Performs a TLS handshake with {@code host:port} that accepts any certificate, and returns the certificate
     * chain the server presented. The socket is always closed.
     *
     * @throws CertificateException if the server presented no certificate
     */
    private static X509Certificate[] fetchServerCertificateChain(String host, int port, int connectTimeout,
                                                                 int readTimeout)
            throws IOException, GeneralSecurityException {
        TrustAllManager trustAllManager = new TrustAllManager();
        // Dedicated context so the trust-all manager never ends up on the context used for the real request.
        SSLContext sslContext = SSLContext.getInstance("TLS");
        sslContext.init(null, new TrustManager[]{trustAllManager}, null);
        try (Socket raw = new Socket()) {
            raw.connect(new InetSocketAddress(host, port), connectTimeout);
            raw.setSoTimeout(readTimeout);
            // Layer TLS over the connected socket; autoClose=true also closes the raw socket.
            try (SSLSocket socket = (SSLSocket) sslContext.getSocketFactory().createSocket(raw, host, port, true)) {
                socket.startHandshake();
            }
        }

        X509Certificate[] chain = trustAllManager.chain;
        if (chain == null || chain.length == 0) {
            throw new CertificateException("获取对端服务证书失败");
        }
        return chain;
    }

    /**
     * Trust manager that accepts every certificate and records the server chain seen during the handshake.
     * It overrides Java's default certificate verification and is used only to fetch the server certificate.
     */
    private static class TrustAllManager implements X509TrustManager {
        private X509Certificate[] chain;

        @Override
        public void checkClientTrusted(X509Certificate[] chain, String authType) {
            // Intentionally accepts everything; this manager is only used on the client side of a handshake.
        }

        @Override
        public void checkServerTrusted(X509Certificate[] chain, String authType) {
            this.chain = chain;
        }

        @Override
        public X509Certificate[] getAcceptedIssuers() {
            // The X509TrustManager contract requires a non-null (possibly empty) array.
            return new X509Certificate[0];
        }
    }

    /**
     * Manual smoke test: requests a sample HTTPS URL with this initializer, caching certificates in {@code D:/}.
     */
    public static void main(String[] args) throws IOException {
        String url = "https://b2b.10086.cn/b2b/main/viewNoticeContent.html?noticeBean.id=";

        final HttpsAutoAuthInitializer httpsAutoAuthInitializer = new HttpsAutoAuthInitializer();
        httpsAutoAuthInitializer.setCertificateDir("D:/");
        // To use another JDK's trust store as the base, call for example:
        // httpsAutoAuthInitializer.setJdkDefaultCacerts("C:\\Program Files\\Java\\jdk1.8.0_201\\jre\\lib\\security\\cacerts");
        RestfulHttpClient.HttpResponse response = RestfulHttpClient.getClient(url)
                .addInitializer(httpsAutoAuthInitializer)
                .request(); // send the request

        System.out.println(response.getCode());       // response status code
        System.out.println(response.getRequestUrl()); // URL the request was finally sent to
        if (response.getCode() == 200) {
            // request succeeded
            System.out.println(response.getContent()); // response body
        }
    }

    /** Returns the password used for the JDK cacerts store and the cached per-host key stores. */
    public String getPassword() {
        return password;
    }

    /**
     * Sets the password used both to read the JDK cacerts store and to protect the cached per-host key stores.
     * Defaults to {@code "changeit"}.
     */
    public void setPassword(String password) {
        this.password = password;
    }

    /** Returns the directory in which {@code <host>.cer} key stores are cached. */
    public String getCertificateDir() {
        return certificateDir;
    }

    /** Sets the directory in which {@code <host>.cer} key stores are cached. Defaults to {@code "."}. */
    public void setCertificateDir(String certificateDir) {
        this.certificateDir = certificateDir;
    }

    /** Returns the path of the cacerts file used as the base of every downloaded key store. */
    public String getJdkDefaultCacerts() {
        return jdkDefaultCacerts;
    }

    /** Sets the path of the cacerts file used as the base of every downloaded key store. */
    public void setJdkDefaultCacerts(String jdkDefaultCacerts) {
        this.jdkDefaultCacerts = jdkDefaultCacerts;
    }
}
